Optimizing Stan code for a user-provided Gamma model

In [3]:
%load_ext autoreload
%autoreload 2
%matplotlib inline
import random
random.seed(1100038344)
import survivalstan
import numpy as np
import pandas as pd
from stancache import stancache
from matplotlib import pyplot as plt
The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
/home/jacquelineburos/miniconda3/envs/python3/lib/python3.5/site-packages/Cython/Distutils/old_build_ext.py:30: UserWarning: Cython.Distutils.old_build_ext does not properly handle dependencies and is deprecated.
  "Cython.Distutils.old_build_ext does not properly handle dependencies "
/home/jacquelineburos/.local/lib/python3.5/site-packages/IPython/html.py:14: ShimWarning: The `IPython.html` package has been deprecated. You should import from `notebook` instead. `IPython.html.widgets` has moved to `ipywidgets`.
  "`IPython.html.widgets` has moved to `ipywidgets`.", ShimWarning)
INFO:stancache.seed:Setting seed to 1245502385

Simulate survival data

In order to demonstrate the use of this model, we will first simulate some survival data using survivalstan.sim.sim_data_exp_correlated. As the name implies, this function simulates data assuming a constant hazard throughout the follow-up time period, which is consistent with the Exponential survival function.

This function includes two simulated covariates by default (age and sex). We also simulate a situation where hazard is a function of the simulated value for sex.

We also center the age variable since this will make it easier to interpret estimates of the baseline hazard.

In [4]:
d = stancache.cached(
    survivalstan.sim.sim_data_exp_correlated,
    N=100,
    censor_time=20,
    rate_form='1 + sex',
    rate_coefs=[-3, 0.5],
)
d['age_centered'] = d['age'] - d['age'].mean()
INFO:stancache.stancache:sim_data_exp_correlated: cache_filename set to sim_data_exp_correlated.cached.N_100.censor_time_20.rate_coefs_54462717316.rate_form_1 + sex.pkl
INFO:stancache.stancache:sim_data_exp_correlated: Loading result from cache

*Aside: In order to make this a more reproducible example, this code is using a file-caching function stancache.cached to wrap a function call to survivalstan.sim.sim_data_exp_correlated. *

Explore simulated data

Here is what these data look like - this is per-subject or time-to-event form:

In [5]:
d.head()
Out[5]:
age sex rate true_t t event index age_centered
0 59 male 0.082085 20.948771 20.000000 False 0 4.18
1 58 male 0.082085 12.827519 12.827519 True 1 3.18
2 61 female 0.049787 27.018886 20.000000 False 2 6.18
3 57 female 0.049787 62.220296 20.000000 False 3 2.18
4 55 male 0.082085 10.462045 10.462045 True 4 0.18

It’s not that obvious from the field names, but in this example “subjects” are indexed by the field ``index``.

We can plot these data using lifelines, or the rudimentary plotting functions provided by survivalstan.

In [6]:
survivalstan.utils.plot_observed_survival(df=d[d['sex']=='female'], event_col='event', time_col='t', label='female')
survivalstan.utils.plot_observed_survival(df=d[d['sex']=='male'], event_col='event', time_col='t', label='male')
plt.legend()
Out[6]:
<matplotlib.legend.Legend at 0x7f64d8063978>
../_images/examples_Test_new_gamma_survival_model_with_simulated_data_11_1.png

model1: original spec

In [7]:
model_code = '''
functions {
  // Defines the log survival
  vector log_S (vector t, real shape, vector rate) {
    vector[num_elements(t)] log_S;
    for (i in 1:num_elements(t)) {
      log_S[i] = gamma_lccdf(t[i]|shape,rate[i]);
    }
    return log_S;
  }

  // Defines the log hazard
  vector log_h (vector t, real shape, vector rate) {
    vector[num_elements(t)] log_h;
    vector[num_elements(t)] ls;
    ls = log_S(t,shape,rate);
    for (i in 1:num_elements(t)) {
      log_h[i] = gamma_lpdf(t[i]|shape,rate[i]) - ls[i];
    }
    return log_h;
  }

  // Defines the sampling distribution
  real surv_gamma_lpdf (vector t, vector d, real shape, vector rate) {
    vector[num_elements(t)] log_lik;
    real prob;
    log_lik = d .* log_h(t,shape,rate) + log_S(t,shape,rate);
    prob = sum(log_lik);
    return prob;
  }
}

data {
  int N;                            // number of observations
  vector<lower=0>[N] y;             // observed times
  vector<lower=0,upper=1>[N] event;     // censoring indicator (1=observed, 0=censored)
  int M;                            // number of covariates
  matrix[N, M] x;                    // matrix of covariates (with n rows and H columns)
}

parameters {
  vector[M] beta;         // Coefficients in the linear predictor (including intercept)
  real<lower=0> alpha;    // shape parameter
}

transformed parameters {
  vector[N] linpred;
  vector[N] mu;
  linpred = x*beta;
  for (i in 1:N) {
    mu[i] = exp(linpred[i]);
  }
}

model {
  alpha ~ gamma(0.01,0.01);
  beta ~ normal(0,5);
  y ~ surv_gamma(event, alpha, mu);
}
'''

Now, we are ready to fit our model using survivalstan.fit_stan_survival_model.

We pass a few parameters to the fit function, many of which are required. See ?survivalstan.fit_stan_survival_model for details.

Similar to what we did above, we are asking survivalstan to cache this model fit object. See stancache for more details on how this works. Also, if you didn’t want to use the cache, you could omit the parameter FIT_FUN and survivalstan would use the standard pystan functionality.

In [8]:
testfit = survivalstan.fit_stan_survival_model(
    model_cohort = 'model 1',
    model_code = model_code,
    df = d,
    time_col = 't',
    event_col = 'event',
    formula = '~ age_centered + sex',
    iter = 5000,
    chains = 4,
    seed = 9001,
    FIT_FUN = stancache.cached_stan_fit,
    drop_intercept = False,
    )

INFO:stancache.stancache:Step 1: Get compiled model code, possibly from cache
INFO:stancache.stancache:StanModel: cache_filename set to anon_model.cython_0_25_1.model_code_14429915565770599621.pystan_2_12_0_0.stanmodel.pkl
INFO:stancache.stancache:StanModel: Loading result from cache
INFO:stancache.stancache:Step 2: Get posterior draws from model, possibly from cache
INFO:stancache.stancache:sampling: cache_filename set to anon_model.cython_0_25_1.model_code_14429915565770599621.pystan_2_12_0_0.stanfit.chains_4.data_25476010973.iter_5000.seed_9001.pkl
INFO:stancache.stancache:sampling: Loading result from cache
In [9]:
# 0:00:40.518775 elapsed
In [10]:
survivalstan.utils.print_stan_summary([testfit], pars=['lp__', 'alpha', 'beta'])
               mean   se_mean        sd        2.5%         50%       97.5%      Rhat
lp__    -278.113380  0.021718  1.358348 -281.579969 -277.815032 -276.424942  1.000465
alpha      1.220129  0.002561  0.170488    0.913076    1.209632    1.581949  1.000245
beta[0]   -2.703902  0.003699  0.226779   -3.173168   -2.695839   -2.275841  1.000271
beta[1]    0.608266  0.002897  0.199169    0.221442    0.606373    1.003776  1.000226
beta[2]    0.006018  0.000183  0.014844   -0.023068    0.005959    0.035536  1.000075

model2: alternate version of surv_gamma_lpdf

In [11]:
model_code2 = '''
functions {
  // Defines the log survival
  real surv_gamma_lpdf (vector t, vector d, real shape, vector rate) {
    vector[num_elements(t)] log_lik;
    real prob;
    for (i in 1:num_elements(t)) {
      log_lik[i] = d[i] * (gamma_lpdf(t[i]|shape,rate[i]) - gamma_lccdf(t[i]|shape,rate[i]))
                      + gamma_lccdf(t[i]|shape,rate[i]);
    }
    prob = sum(log_lik);
    return prob;
  }
}
data {
  int N;                            // number of observations
  vector<lower=0>[N] y;             // observed times
  vector<lower=0,upper=1>[N] event;     // censoring indicator (1=observed, 0=censored)
  int M;                            // number of covariates
  matrix[N, M] x;                    // matrix of covariates (with n rows and H columns)
}
parameters {
  vector[M] beta;         // Coefficients in the linear predictor (including intercept)
  real<lower=0> alpha;    // shape parameter
}
transformed parameters {
  vector<lower=0>[N] mu;
  {
      vector[N] linpred;
      linpred = x*beta;
      mu = exp(linpred);
  }
}

model {
  alpha ~ gamma(0.01,0.01);
  beta ~ normal(0,5);
  y ~ surv_gamma(event, alpha, mu);
}
'''
In [12]:
testfit2 = survivalstan.fit_stan_survival_model(
    model_cohort = 'model 2',
    model_code = model_code2,
    df = d,
    time_col = 't',
    event_col = 'event',
    formula = '~ age_centered + sex',
    iter = 5000,
    chains = 4,
    seed = 9001,
    FIT_FUN = stancache.cached_stan_fit,
    drop_intercept = False,
    )

INFO:stancache.stancache:Step 1: Get compiled model code, possibly from cache
INFO:stancache.stancache:StanModel: cache_filename set to anon_model.cython_0_25_1.model_code_9177012762674257483.pystan_2_12_0_0.stanmodel.pkl
INFO:stancache.stancache:StanModel: Loading result from cache
INFO:stancache.stancache:Step 2: Get posterior draws from model, possibly from cache
INFO:stancache.stancache:sampling: cache_filename set to anon_model.cython_0_25_1.model_code_9177012762674257483.pystan_2_12_0_0.stanfit.chains_4.data_25476010973.iter_5000.seed_9001.pkl
INFO:stancache.stancache:sampling: Loading result from cache
In [13]:
# 0:00:21.081723 elapsed
In [14]:
survivalstan.utils.print_stan_summary([testfit2], pars=['lp__', 'alpha', 'beta'])
               mean   se_mean        sd        2.5%         50%       97.5%      Rhat
lp__    -278.122019  0.022342  1.367441 -281.523430 -277.808097 -276.440635  1.001051
alpha      1.218976  0.002590  0.171529    0.906313    1.208893    1.586105  1.000714
beta[0]   -2.704073  0.003788  0.228848   -3.187193   -2.693108   -2.287256  1.000886
beta[1]    0.604867  0.002814  0.201056    0.208092    0.605508    0.993031  1.000143
beta[2]    0.006629  0.000188  0.014733   -0.021872    0.006506    0.036200  1.000063

model3: use log_mix inside surv_gamma_lpdf

In [15]:
model_code3 = '''
functions {
  // Defines the log survival
  real surv_gamma_lpdf (vector t, vector d, real shape, vector rate) {
    vector[num_elements(t)] log_lik;
    real prob;
    for (i in 1:num_elements(t)) {
      log_lik[i] = log_mix(d[i], gamma_lpdf(t[i]|shape,rate[i]), gamma_lccdf(t[i]|shape,rate[i]));
    }
    prob = sum(log_lik);
    return prob;
  }
}
data {
  int N;                            // number of observations
  vector<lower=0>[N] y;             // observed times
  vector<lower=0,upper=1>[N] event; // censoring indicator (1=observed, 0=censored)
  int M;                            // number of covariates
  matrix[N, M] x;                   // matrix of covariates (with n rows and H columns)
}
parameters {
  vector[M] beta;         // Coefficients in the linear predictor (including intercept)
  real<lower=0> alpha;    // shape parameter
}
transformed parameters {
  vector[N] linpred;
  vector[N] mu;
  linpred = x*beta;
  mu = exp(linpred);
}

model {
  alpha ~ gamma(0.01,0.01);
  beta ~ normal(0,5);
  y ~ surv_gamma(event, alpha, mu);
}
'''
In [16]:
testfit3 = survivalstan.fit_stan_survival_model(
    model_cohort = 'model 3',
    model_code = model_code3,
    df = d,
    time_col = 't',
    event_col = 'event',
    formula = '~ age_centered + sex',
    iter = 5000,
    chains = 4,
    seed = 9001,
    FIT_FUN = stancache.cached_stan_fit,
    drop_intercept = False,
    )

INFO:stancache.stancache:Step 1: Get compiled model code, possibly from cache
INFO:stancache.stancache:StanModel: cache_filename set to anon_model.cython_0_25_1.model_code_1293841621968646714.pystan_2_12_0_0.stanmodel.pkl
INFO:stancache.stancache:StanModel: Loading result from cache
INFO:stancache.stancache:Step 2: Get posterior draws from model, possibly from cache
INFO:stancache.stancache:sampling: cache_filename set to anon_model.cython_0_25_1.model_code_1293841621968646714.pystan_2_12_0_0.stanfit.chains_4.data_25476010973.iter_5000.seed_9001.pkl
INFO:stancache.stancache:sampling: Loading result from cache
In [17]:
#0:00:20.284146 elapsed
In [18]:
survivalstan.utils.print_stan_summary([testfit3], pars=['lp__', 'alpha', 'beta'])
               mean   se_mean        sd        2.5%         50%       97.5%      Rhat
lp__    -278.092768  0.021055  1.344719 -281.491359 -277.783232 -276.425885  1.000812
alpha      1.216429  0.002640  0.166944    0.905732    1.207800    1.576275  1.000432
beta[0]   -2.708385  0.003770  0.220966   -3.173480   -2.697460   -2.307317  1.000622
beta[1]    0.609834  0.002799  0.199563    0.223234    0.609396    1.008451  1.000295
beta[2]    0.006028  0.000180  0.014812   -0.022443    0.006036    0.036197  1.000265

model4: vectorize surv_gamma_lpdf

In [40]:
model_code4 = '''
functions {
    int count_value(vector a, real val) {
        int s;
        s = 0;
        for (i in 1:num_elements(a))
            if (a[i] == val)
                s = s + 1;
        return s;
    }

  // Defines the log survival
  real surv_gamma_lpdf (vector t, vector d, real shape, vector rate, int num_cens, int num_obs) {
    vector[2] log_lik;
    int idx_obs[num_obs];
    int idx_cens[num_cens];
    real prob;
    int i_cens;
    int i_obs;
    i_cens = 1;
    i_obs = 1;
    for (i in 1:num_elements(t)) {
        if (d[i] == 1) {
            idx_obs[i_obs] = i;
            i_obs = i_obs+1;
        }
        else {
            idx_cens[i_cens] = i;
            i_cens = i_cens+1;
        }
    }
    print(idx_obs);
    log_lik[1] = gamma_lpdf(t[idx_obs] | shape, rate[idx_obs]);
    log_lik[2] = gamma_lccdf(t[idx_cens] | shape, rate[idx_cens]);
    prob = sum(log_lik);
    return prob;
  }
}
data {
  int N;                            // number of observations
  vector<lower=0>[N] y;             // observed times
  vector<lower=0,upper=1>[N] event; // censoring indicator (1=observed, 0=censored)
  int M;                            // number of covariates
  matrix[N, M] x;                   // matrix of covariates (with n rows and H columns)
}
transformed data {
  int num_cens;
  int num_obs;
  num_obs = count_value(event, 1);
  num_cens = N - num_obs;
}
parameters {
  vector[M] beta;         // Coefficients in the linear predictor (including intercept)
  real<lower=0> alpha;    // shape parameter
}
transformed parameters {
  vector[N] linpred;
  vector[N] mu;
  linpred = x*beta;
  mu = exp(linpred);
}
model {
  alpha ~ gamma(0.01,0.01);
  beta ~ normal(0,5);
  y ~ surv_gamma(event, alpha, mu, num_cens, num_obs);
}
'''
In [41]:
testfit4 = survivalstan.fit_stan_survival_model(
    model_cohort = 'model 4',
    model_code = model_code4,
    df = d,
    time_col = 't',
    event_col = 'event',
    formula = '~ age_centered + sex',
    iter = 5000,
    chains = 4,
    seed = 9001,
    FIT_FUN = stancache.cached_stan_fit,
    drop_intercept = False,
    )

INFO:stancache.stancache:Step 1: Get compiled model code, possibly from cache
INFO:stancache.stancache:StanModel: cache_filename set to anon_model.cython_0_25_1.model_code_16881928540873162731.pystan_2_12_0_0.stanmodel.pkl
INFO:stancache.stancache:StanModel: Starting execution
INFO:pystan:COMPILING THE C++ CODE FOR MODEL anon_model_b18a6495e568fcff90662e16a3d2aa85 NOW.
INFO:stancache.stancache:StanModel: Execution completed (0:01:09.439292 elapsed)
INFO:stancache.stancache:StanModel: Saving results to cache
INFO:stancache.stancache:Step 2: Get posterior draws from model, possibly from cache
INFO:stancache.stancache:sampling: cache_filename set to anon_model.cython_0_25_1.model_code_16881928540873162731.pystan_2_12_0_0.stanfit.chains_4.data_25476010973.iter_5000.seed_9001.pkl
INFO:stancache.stancache:sampling: Starting execution
INFO:stancache.stancache:sampling: Execution completed (0:00:06.245552 elapsed)
INFO:stancache.stancache:sampling: Saving results to cache
/home/jacquelineburos/miniconda3/envs/python3/lib/python3.5/site-packages/stancache/stancache.py:284: UserWarning: Pickling fit objects is an experimental feature!
The relevant StanModel instance must be pickled along with this fit object.
When unpickling the StanModel must be unpickled first.
  pickle.dump(res, open(cache_filepath, 'wb'), pickle.HIGHEST_PROTOCOL)
In [42]:
# 0:00:06.245552 elapsed
In [43]:
survivalstan.utils.print_stan_summary([testfit4], pars=['lp__', 'alpha', 'beta'])
               mean   se_mean        sd        2.5%         50%       97.5%      Rhat
lp__    -278.097416  0.020380  1.353386 -281.543138 -277.782205 -276.445123  0.999867
alpha      1.216149  0.002525  0.169578    0.913312    1.206908    1.574024  1.000702
beta[0]   -2.708518  0.003627  0.228658   -3.188661   -2.700493   -2.278788  1.000723
beta[1]    0.610593  0.002824  0.201974    0.224438    0.607700    1.030164  1.000744
beta[2]    0.006211  0.000171  0.014493   -0.022229    0.006113    0.035095  1.000222

compare coefficient estimates for each model spec

In [44]:
survivalstan.utils.plot_coefs([testfit, testfit2, testfit3, testfit4])
../_images/examples_Test_new_gamma_survival_model_with_simulated_data_34_0.png
In [ ]: